{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "WBK_EQ2mjh1h"
   },
   "source": [
    "# Set fold number"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "ILQh5fsg8g_w"
   },
   "outputs": [],
   "source": [
    "# Fold number: selects the `{index}/` data directory and the `4fold{index}` run names used in later cells.\n",
    "index = 3"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "tjVbXSt9jmwU"
   },
   "source": [
    "# Train"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Put the GWC_YOLOv5 working tree on the master branch before using its scripts.\n",
    "!git -C GWC_YOLOv5 checkout master"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "YDCc1hVhqKP7"
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import shutil\n",
    "\n",
    "# Copy every image named in the fold's split lists out of the shared `train/` pool\n",
    "# into `{index}/images/val/` and `{index}/images/train/`.\n",
    "for split in (\"val\", \"train\"):\n",
    "    dest_dir = f\"{index}/images/{split}\"\n",
    "    # shutil.copyfile does not create missing parent folders.\n",
    "    os.makedirs(dest_dir, exist_ok=True)\n",
    "    with open(f\"{index}/{split}.txt\") as listing:\n",
    "        for line in listing:\n",
    "            name = line.strip()\n",
    "            if not name:\n",
    "                # A blank line would otherwise resolve to the `train/` directory itself.\n",
    "                continue\n",
    "            shutil.copyfile(\"train/\" + name, f\"{dest_dir}/{name}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "grMOv3iXxCpk"
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import shutil\n",
    "\n",
    "import yaml  # PyYAML; needed below to write the dataset config\n",
    "\n",
    "\n",
    "def run_pseudo_labeling(weights, base_dir, generate=True):\n",
    "    \"\"\"Pseudo-label the images in `test/` with `weights`, then fine-tune on fold + pseudo data.\n",
    "\n",
    "    Reads the notebook-level `index` (fold number) for run names and data paths.\n",
    "\n",
    "    Args:\n",
    "        weights: Checkpoint path handed to `detect.py` and `train.py`. Both commands run\n",
    "            after `cd GWC_YOLOv5`, so a relative path is resolved from that directory.\n",
    "        base_dir: Directory expected to hold the `.txt` label files produced by the\n",
    "            detect step. It is listed from the notebook's working directory, not from\n",
    "            `GWC_YOLOv5`.\n",
    "        generate: When True, delete and rebuild `pseudo/` by running detection on\n",
    "            `../test`. When False, skip detection and reuse the existing `pseudo/` tree.\n",
    "    \"\"\"\n",
    "    if generate:\n",
    "        # Start from an empty pseudo/ tree so files from an earlier run cannot leak in.\n",
    "        if os.path.exists('pseudo'):\n",
    "            shutil.rmtree('pseudo')\n",
    "        os.makedirs('pseudo/images/train')\n",
    "        os.makedirs('pseudo/labels/train')\n",
    "        # NOTE(review): upstream YOLOv5 appends a numeric suffix to --name when the run\n",
    "        # directory already exists; if this fork does the same, `base_dir` may point at\n",
    "        # labels from an older run. Confirm before re-running.\n",
    "        !cd GWC_YOLOv5 && python detect.py --name 4fold{index} --img-size 800 --weights {weights} --source ../test --augment --nosave --save-txt --conf-thres 0.5\n",
    "\n",
    "        for file in os.listdir(base_dir):\n",
    "            stem, ext = os.path.splitext(file)\n",
    "            if ext != '.txt':\n",
    "                # Only label files have a matching test image; skip anything else.\n",
    "                continue\n",
    "            shutil.copyfile(os.path.join(base_dir, file), 'pseudo/labels/train/' + file)\n",
    "            image = stem + '.png'\n",
    "            shutil.copyfile(os.path.join('test', image), 'pseudo/images/train/' + image)\n",
    "\n",
    "    # Dataset config: train on the fold's images plus the pseudo-labelled images,\n",
    "    # validate on the fold's validation images.\n",
    "    result = {\n",
    "        'train': [f'../{index}/images/train', '../pseudo/images/train'],\n",
    "        'val': f'../{index}/images/val',\n",
    "        'nc': 1,\n",
    "        'names': ['wheat'],\n",
    "    }\n",
    "    with open('GWC_YOLOv5/data/custom_pseudo.yaml', 'w') as f:\n",
    "        yaml.dump(result, f, default_flow_style=False)\n",
    "    !cd GWC_YOLOv5 && python train.py --name 4fold{index}_pseudo_master --img 800 --batch 8 --epochs 15 --data custom_pseudo.yaml --weights {weights}  --hyp data/hyp.finetune.yaml --cache"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "XDQTZL-8x2d1"
   },
   "outputs": [],
   "source": [
    "# `weights` is resolved inside GWC_YOLOv5 (the shell steps `cd` there first);\n",
    "# `base_dir` is read from the notebook's working directory.\n",
    "run_pseudo_labeling(\n",
    "    weights=f\"runs/train/4fold{index}/weights/best.pt\",\n",
    "    base_dir=f\"GWC_YOLOv5/runs/detect/4fold{index}/labels\",\n",
    ")"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "collapsed_sections": [],
   "machine_shape": "hm",
   "name": "Copy of YOLOv5_czz.ipynb",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
